import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
import numpy as np
import os, glob, sys, subprocess
from matplotlib.colors import LinearSegmentedColormap
from math import pi
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Short alias for NumPy's vector/matrix norm.
norm = np.linalg.norm

# Global typography for every figure drawn by this script: STIX fonts for
# both regular text and mathtext, at an 8.5 pt base size.
matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 8.5,
})

# Matplotlib "tab:" named colours, indexed by position elsewhere in the file.
tabcolors = [f'tab:{name}' for name in (
    'blue', 'orange', 'green', 'red', 'purple',
    'brown', 'pink', 'gray', 'olive', 'cyan',
)]


def _parse_args(argv):
    """Return ``(k, t)`` as strings taken from the command line.

    Exits with a usage message (instead of an ``IndexError``) when the two
    required positional arguments are missing.
    """
    if len(argv) < 3:
        sys.exit("usage: python <script> <k> <temperature_K>")
    return argv[1], argv[2]


def _make_axes(fig, labelw=0.15, labelh=0.12, wspace=0.0, hspace=0.0):
    """Add three full-width panels stacked vertically; return them top to bottom.

    ``labelw``/``labelh`` reserve figure-fraction margins on the left/bottom
    for tick and axis labels.
    """
    wsize = 0.95 - labelw - wspace
    hsize = (0.85 - labelh - hspace) / 3.0
    return [fig.add_axes([labelw, (2 - i) * hsize + labelh, wsize, hsize])
            for i in range(3)]


def _load_samples(k, stride=1):
    """Load ``(colvar, xi)`` from ``{k}_sum.npz``, printing each array's shape.

    ``xi`` is flattened with ``np.hstack``; both arrays are subsampled by
    ``stride``. The archive is closed before returning.
    """
    with np.load(f"{k}_sum.npz") as data:
        for key in data:
            print(key, data[key].shape)
        colvar = data['colvar'][::stride]
        xi = np.hstack(data['xi'])[::stride]
    return colvar, xi


def _grid_image(table, nbin=60):
    """Turn a flattened ``nbin x nbin`` table into an ``imshow`` image + extent.

    Columns 0 and 1 are taken as the x/y grid coordinates (presumably bin
    centres -- TODO confirm against the UWHAM output format) and column 3 as
    the plotted value. The reshape followed by a transpose assumes column 0
    is the slow (outer) index of the flattened grid.

    Returns ``(image, extent)``; ``extent`` puts the pixel edges half a bin
    outside the outermost grid coordinates so each pixel is centred on its
    bin.
    """
    image = table[:, 3].reshape([nbin, nbin]).T
    xmin, xmax = table[:, 0].min(), table[:, 0].max()
    ymin, ymax = table[:, 1].min(), table[:, 1].max()
    # Spacing from the span of the grid. Differencing two consecutive rows
    # would give 0 along the outer axis, because consecutive rows of a
    # flattened grid share that coordinate.
    dx = (xmax - xmin) / (nbin - 1)
    dy = (ymax - ymin) / (nbin - 1)
    extent = [xmin - dx / 2, xmax + dx / 2, ymin - dy / 2, ymax + dy / 2]
    return image, extent


def main():
    """Build the three-panel xi/phi figure for one run and save it as PNG.

    Command line: ``<k> <temperature>``. Reads ``{k}_sum.npz`` and three
    ``{k}_*_uwham.hist`` tables from the working directory and writes
    ``xi_{t}K_k_{k}_scatter.png``.

    Panels, top to bottom:
      (a) 1-D free-energy curves: against -xi (top axis) and phi (dashed);
      (b) 2-D free-energy map over (phi, theta);
      (c) sampled (phi, theta) points coloured by xi.
    """
    k, t_str = _parse_args(sys.argv)
    fig_name = f"xi_{t_str}K_k_{k}_scatter"
    t = float(t_str)

    fig = plt.figure(figsize=(1.7, 3.5), dpi=1200)
    axs = _make_axes(fig)
    axt = axs[0].twiny()  # extra x axis on top of panel (a) for -xi

    colvar, xi = _load_samples(k)
    fe = [np.loadtxt(os.path.join(".", name)) for name in (
        f"{k}_xi_dw_us_uwham.hist",  # 1-D curve along xi (panel a, top axis)
        f"{k}_dw_us_uwham.hist",     # 1-D curve along phi (panel a, dashed)
        f"{k}_AlaD_re_uwham.hist",   # 2-D grid (panel b)
    )]

    # Colour-scale limits: free-energy map ceiling depends on temperature.
    fe_vmax = 0.1 if t < 200 else 0.25
    ximin, ximax = -20, 17

    # Panels (b) and (c).
    image, extent = _grid_image(fe[2])
    c_fe = axs[1].imshow(image, extent=extent, vmin=0, vmax=fe_vmax,
                         origin='lower', aspect='auto')
    c_xi = axs[2].scatter(colvar[:, 0], colvar[:, 1], c=xi, cmap='inferno',
                          vmin=ximin, vmax=ximax, s=0.1, linewidths=0)

    # Shared phi axis: identical limits and tick positions on all panels,
    # set before panel (a) is plotted so its x range stays fixed.
    phi_ticks = [-pi / 2., -pi / 4., 0, pi / 4., pi / 2.]
    for ax in axs:
        ax.set_xlim([-pi * 0.6, pi * 0.5])
        ax.set_xticks(phi_ticks)
    axs[0].tick_params(axis='x', labelbottom=False)
    axs[1].tick_params(axis='x', labelbottom=False)
    # Last label left blank (as before), e.g. to avoid crowding the corner.
    axs[2].set_xticklabels([r"$-\pi/2$", r"$-\pi/4$", "0", r"$\pi/4$", ""])
    axs[2].set_xlabel(r'$\phi$')

    # theta axis and an inset horizontal colour bar for panels (b) and (c).
    colorbars = []
    for ax, mappable in ((axs[1], c_fe), (axs[2], c_xi)):
        ax.set_ylim([-pi * 0.8, pi * 0.8])
        ax.set_yticks([-pi / 2., 0, pi / 2.])
        # NOTE(review): theta tick labels are drawn blank; use
        # [r"$-\pi/2$", "0", r"$\pi/2$"] instead if they should be shown.
        ax.set_yticklabels(["", "", ""])
        ax.set_ylabel(r'$\theta$', labelpad=0)
        # Float width/height are in inches; the anchor box is in axes units.
        cax = inset_axes(ax, width=0.5, height=0.05,
                         bbox_to_anchor=(0.40, 0.95, 0.5, 0.05),
                         bbox_transform=ax.transAxes)
        cax.xaxis.set_ticks_position("bottom")
        colorbars.append(fig.colorbar(mappable, orientation="horizontal",
                                      cax=cax))
    cb_fe, cb_xi = colorbars
    cb_fe.set_label("Free Energy (eV)", labelpad=0, fontsize=6)
    cb_xi.set_label("$\\xi$", labelpad=0, fontsize=6)
    cb_xi.set_ticks([-15, 0, 15])
    for cb in colorbars:
        cb.ax.tick_params(labelsize=6)

    # Panel (a): the xi curve is drawn against -xi on the top axis.
    xi_curve, phi_curve = fe[0], fe[1]
    axt.plot(-xi_curve[:, 0], xi_curve[:, 2])
    axs[0].plot(phi_curve[:, 0], phi_curve[:, 2], '--', color=tabcolors[4])
    axs[0].set_ylim([0, 0.5])
    print("supposed to be xi", 0, np.max(xi_curve[:, 0]))
    print("supposed to be phi", 0, np.max(phi_curve[:, 0]))
    axs[0].set_yticks([0, 0.1, 0.2, 0.3, 0.4])
    axs[0].set_ylabel("Free energy (eV)", labelpad=5)

    axt.set_xticks([-15, 0, 15])
    axt.set_xlabel("$-\\xi$")
    axt.xaxis.label.set_color(tabcolors[0])
    axt.tick_params(axis='x', colors=tabcolors[0])
    axt.spines['top'].set_color(tabcolors[0])

    # Panel letters, title, and curve labels (the last two in data coords).
    for ax, letter in zip(axs, "abc"):
        ax.text(0.1, 0.95, f"({letter})", transform=ax.transAxes,
                fontsize=8, va='top', ha='right')
    axs[0].text(0.5, 1.55, "CV=$\\phi$", transform=axs[0].transAxes,
                fontsize=10, va='top', ha='center')
    axt.text(8, 0.3, "$\\xi$",
             fontsize=10, va='top', ha='center', color=tabcolors[0])
    axs[0].text(-0.55, 0.3, "$\\phi$",
                fontsize=10, va='top', ha='center', color=tabcolors[4])

    fig.savefig(f"{fig_name}.png", dpi=300)
    plt.close(fig)

# Script entry point: build and save the figure only when run directly.
if __name__ == "__main__":
    main()

